Skip to content

Fix TrainableBilateralFilter 3D input validation (#7444)#8729

Open
getrichthroughcode wants to merge 3 commits intoProject-MONAI:devfrom
getrichthroughcode:fix/trainable-bilateral-filter-7444
Open

Fix TrainableBilateralFilter 3D input validation (#7444)#8729
getrichthroughcode wants to merge 3 commits intoProject-MONAI:devfrom
getrichthroughcode:fix/trainable-bilateral-filter-7444

Conversation

@getrichthroughcode
Copy link

  • Fix dimension comparison to use spatial dims instead of total dims
  • Add validation for minimum input dimensions
  • Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma)
  • Move spatial dimension validation before unsqueeze operations

The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected.

Fixes #7444

Description

This PR fixes a validation bug in TrainableBilateralFilter that incorrectly rejected valid 3D inputs with shape (B, C, H, W, D).

Root Cause: The forward() method compared self.len_spatial_sigma (spatial dimensions = 3) with len(input_tensor.shape) (total dimensions = 5), causing a dimension mismatch error for valid inputs.

Solution: Calculate spatial_dims = len(input_tensor.shape) - 2 to exclude batch and channel dimensions, then compare against self.len_spatial_sigma.

Example of fixed behavior:

# Previously failed, now works
bf = TrainableBilateralFilter([1.0, 1.0, 1.0], 1.0)
x = torch.randn(1, 1, 10, 10, 10)  # (B, C, H, W, D)
out = bf(x)  # Success!

This fix also improves error messages and adds validation for inputs with insufficient dimensions.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Notes on Testing

The existing unit tests for TrainableBilateralFilter (24 tests) require the C++ extension and were skipped locally (expected behavior with @skip_if_no_cpp_extension decorator). These tests will run automatically in CI.

I verified the fix logic with custom local tests for 1D, 2D, and 3D cases (see examples in description above).

Linting and code formatting checks passed:

./runtests.sh --autofix     # Passed
./runtests.sh --codeformat  # Passed

No new tests were added as the existing 24 unit tests already cover the behavior. No docstring or documentation changes were needed as this is purely a bug fix in validation logic.

- Fix dimension comparison to use spatial dims instead of total dims
- Add validation for minimum input dimensions
- Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma)
- Move spatial dimension validation before unsqueeze operations

The forward() method was incorrectly comparing self.len_spatial_sigma
(number of spatial dimensions) with len(input_tensor.shape) (total
dimensions including batch and channel), causing valid 3D inputs to
be rejected.

Fixes Project-MONAI#7444

Signed-off-by: Abdoulaye Diallo <abdoulayediallo338@gmail.com>
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 1, 2026

📝 Walkthrough

Walkthrough

Added input dimensionality validation to TrainableBilateralFilter and TrainableJointBilateralFilter requiring at least 3 tensor dimensions. Replaced branching on total input length with computation of spatial_dims = len(input) - 2 and used spatial_dims for 1D/2D handling, unsqueeze/squeeze operations, and spatial-sigma consistency checks. Error messages and spacing were adjusted to reflect the updated dimensionality checks. No public class or method signatures were changed.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly summarizes the main change: fixing 3D input validation in TrainableBilateralFilter, directly addressing the issue.
Description check ✅ Passed Description includes all required sections: fixes reference, detailed description with root cause and solution, types of changes, and testing notes. Comprehensive and well-structured.
Linked Issues check ✅ Passed Changes directly address issue #7444: fix spatial-dimension validation to accept 3D inputs (B,C,H,W,D), correct dimension comparison logic, add input validation, and improve error messages.
Out of Scope Changes check ✅ Passed All changes are scoped to fixing the dimension validation bug in TrainableBilateralFilter and TrainableJointBilateralFilter. No unrelated modifications present.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)

406-430: ⚠️ Potential issue | 🟠 Major

TrainableJointBilateralFilter.forward() not updated with the same fix.

This method still uses len_input directly instead of computing spatial_dims = len_input - 2. It will reject valid 3D inputs just like the original bug in TrainableBilateralFilter. Also missing the minimum dimension validation added to the other class.

Proposed fix
     def forward(self, input_tensor, guidance_tensor):
+        if len(input_tensor.shape) < 3:
+            raise ValueError(
+                f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
+            )
         if input_tensor.shape[1] != 1:
             raise ValueError(
                 f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
                 "Please use multiple parallel filter layers if you want "
                 "to filter multiple channels."
             )
         if input_tensor.shape != guidance_tensor.shape:
             raise ValueError(
                 "Shape of input image must equal shape of guidance image."
                 f"Got {input_tensor.shape} and {guidance_tensor.shape}."
             )

         len_input = len(input_tensor.shape)
+        spatial_dims = len_input - 2

         # C++ extension so far only supports 5-dim inputs.
-        if len_input == 3:
+        if spatial_dims == 1:
             input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
             guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
-        elif len_input == 4:
+        elif spatial_dims == 2:
             input_tensor = input_tensor.unsqueeze(4)
             guidance_tensor = guidance_tensor.unsqueeze(4)

-        if self.len_spatial_sigma != len_input:
-            raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
+        if self.len_spatial_sigma != spatial_dims:
+            raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).")

         prediction = TrainableJointBilateralFilterFunction.apply(
             input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
         )

         # Make sure to return tensor of the same shape as the input.
-        if len_input == 3:
+        if spatial_dims == 1:
             prediction = prediction.squeeze(4).squeeze(3)
-        elif len_input == 4:
+        elif spatial_dims == 2:
             prediction = prediction.squeeze(4)

         return prediction
🤖 Fix all issues with AI agents
In `@monai/networks/layers/filtering.py`:
- Around line 223-225: The error message uses self.len_spatial_sigma which is
not assigned in the branch; fix by referencing the actual expected spatial
dimension attribute or ensuring self.len_spatial_sigma is initialized before
this check: either assign self.len_spatial_sigma = self.spatial_ndim (or the
class's existing spatial-dimension attribute) earlier in the initializer, or
change the ValueError message to use the computed expected dimension (e.g.,
self.spatial_ndim or len(self.spatial_shape)) instead of self.len_spatial_sigma
so the attribute is defined when raising the error in the spatial_sigma
validation.
- Around line 395-398: The else branch references an undefined attribute
self.len_spatial_sigma; fix it by using a defined value (e.g., compute
len_spatial = len(self.spatial_sigma) or use self.spatial_ndim) when building
the error message in the failing branch of the initializer (same place as
TrainableBilateralFilter.__init__). Replace self.len_spatial_sigma with the
actual computed length (len(self.spatial_sigma) or self.spatial_ndim) so the
ValueError message prints a valid expected-dimension value.

Comment on lines 395 to 398
else:
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Same bug: self.len_spatial_sigma undefined in else branch.

Identical issue as TrainableBilateralFilter.__init__.

Proposed fix
         else:
             raise ValueError(
-                f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
+                f"len(spatial_sigma) must be 1, 2, or 3, got {len(spatial_sigma)}."
             )
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 396-398: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@monai/networks/layers/filtering.py` around lines 395 - 398, The else branch
references an undefined attribute self.len_spatial_sigma; fix it by using a
defined value (e.g., compute len_spatial = len(self.spatial_sigma) or use
self.spatial_ndim) when building the error message in the failing branch of the
initializer (same place as TrainableBilateralFilter.__init__). Replace
self.len_spatial_sigma with the actual computed length (len(self.spatial_sigma)
or self.spatial_ndim) so the ValueError message prints a valid
expected-dimension value.

@ericspod
Copy link
Member

ericspod commented Mar 1, 2026

Hi @getrichthroughcode thanks for the contribution, please have a look at the issues coderabbit has commented on and we can then review.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)

233-267: ⚠️ Potential issue | 🟠 Major

Add module-wrapper regression tests for the new validation logic.

Current referenced tests exercise TrainableBilateralFilterFunction.apply() directly, while the changed logic is in TrainableBilateralFilter.forward() and TrainableJointBilateralFilter.forward(). Please add tests for: valid 1D/2D/3D wrapper inputs, ndim < 3, len(spatial_sigma) mismatch, and joint input/guidance shape mismatch.

Suggested test additions
+ # tests/networks/layers/filtering/test_trainable_bilateral.py
+ def test_trainable_bilateral_wrapper_accepts_3d_shape():
+     layer = TrainableBilateralFilter(spatial_sigma=(1.0, 1.0, 1.0), color_sigma=0.2)
+     x = torch.randn(1, 1, 10, 10, 10, dtype=torch.double)
+     y = layer(x)
+     assert y.shape == x.shape
+
+ def test_trainable_bilateral_wrapper_rejects_rank_lt_3():
+     layer = TrainableBilateralFilter(spatial_sigma=(1.0,), color_sigma=0.2)
+     with pytest.raises(ValueError):
+         layer(torch.randn(1, 1))
+
+ # tests/networks/layers/filtering/test_trainable_joint_bilateral.py
+ def test_trainable_joint_wrapper_rejects_shape_mismatch():
+     layer = TrainableJointBilateralFilter(spatial_sigma=(1.0, 1.0, 1.0), color_sigma=0.2)
+     x = torch.randn(1, 1, 10, 10, 10, dtype=torch.double)
+     g = torch.randn(1, 1, 10, 10, 9, dtype=torch.double)
+     with pytest.raises(ValueError):
+         layer(x, g)
As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

Also applies to: 406-447

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/layers/filtering.py` around lines 233 - 267, Add module-level
unit tests that call the wrapper classes (use module(input) or
TrainableBilateralFilter.forward/TrainableJointBilateralFilter.forward) instead
of invoking TrainableBilateralFilterFunction.apply() directly: (1) validate
successful forward passes for valid 1D, 2D, and 3D inputs (batch, channel=1,
spatial dims) matching len(spatial_sigma); (2) assert a ValueError is raised
when ndim < 3; (3) assert a ValueError is raised when len(spatial_sigma) does
not match spatial dims (triggering the check in
TrainableBilateralFilter.forward); and (4) for
TrainableJointBilateralFilter.forward add tests that assert a shape-mismatch
between input and guidance tensors raises the expected error. Ensure tests
construct modules with differing len_spatial_sigma and call the module (not the
C++ function) so the new validation logic in TrainableBilateralFilter.forward
and TrainableJointBilateralFilter.forward is covered.
🧹 Nitpick comments (1)
monai/networks/layers/filtering.py (1)

234-237: Document the new ValueError branches in Google-style Raises sections.

Both modified forward methods now enforce additional input validation but do not document raised exceptions.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

Also applies to: 254-255, 407-410, 434-435

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/layers/filtering.py` around lines 234 - 237, Add Google-style
"Raises" entries to the docstrings for the forward methods in
monai/networks/layers/filtering.py that now validate input dimensions: for each
forward (and any other modified methods referenced around the changed ranges)
add a "Raises" section documenting ValueError with a short sentence like
"ValueError: if input_tensor has fewer than 3 dimensions (batch, channel,
*spatial_dims)" (and similarly for other checks introduced at the other modified
locations). Update the docstring for each function/method name forward (and any
other functions showing new validation at the referenced ranges) to include the
Raises section describing the exact condition that triggers the ValueError so it
follows Google-style docstrings.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@monai/networks/layers/filtering.py`:
- Around line 233-267: Add module-level unit tests that call the wrapper classes
(use module(input) or
TrainableBilateralFilter.forward/TrainableJointBilateralFilter.forward) instead
of invoking TrainableBilateralFilterFunction.apply() directly: (1) validate
successful forward passes for valid 1D, 2D, and 3D inputs (batch, channel=1,
spatial dims) matching len(spatial_sigma); (2) assert a ValueError is raised
when ndim < 3; (3) assert a ValueError is raised when len(spatial_sigma) does
not match spatial dims (triggering the check in
TrainableBilateralFilter.forward); and (4) for
TrainableJointBilateralFilter.forward add tests that assert a shape-mismatch
between input and guidance tensors raises the expected error. Ensure tests
construct modules with differing len_spatial_sigma and call the module (not the
C++ function) so the new validation logic in TrainableBilateralFilter.forward
and TrainableJointBilateralFilter.forward is covered.

---

Nitpick comments:
In `@monai/networks/layers/filtering.py`:
- Around line 234-237: Add Google-style "Raises" entries to the docstrings for
the forward methods in monai/networks/layers/filtering.py that now validate
input dimensions: for each forward (and any other modified methods referenced
around the changed ranges) add a "Raises" section documenting ValueError with a
short sentence like "ValueError: if input_tensor has fewer than 3 dimensions
(batch, channel, *spatial_dims)" (and similarly for other checks introduced at
the other modified locations). Update the docstring for each function/method
name forward (and any other functions showing new validation at the referenced
ranges) to include the Raises section describing the exact condition that
triggers the ValueError so it follows Google-style docstrings.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f7851bf7-ad86-46fd-95e4-9cbb0c48f191

📥 Commits

Reviewing files that changed from the base of the PR and between 4924aa6 and 8264cac.

📒 Files selected for processing (1)
  • monai/networks/layers/filtering.py

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)

233-267: ⚠️ Potential issue | 🟠 Major

Add regression tests for the new forward validation branches.

This PR changes validation/control flow in both forward methods (Line 234 onward and Line 407 onward), but no test updates are shown here. Please add or point to tests covering: valid 1D/2D/3D inputs, <3 rank rejection, and len_spatial_sigma mismatch rejection in both classes.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

Also applies to: 406-447

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/layers/filtering.py` around lines 233 - 267, The forward()
validation branches now added need unit tests: add tests exercising
TrainableBilateralFilter.forward and the other class's forward (the second
forward around lines 406-447) for valid 1D/2D/3D inputs returning same shaped
tensors, for inputs with rank < 3 raising ValueError, for inputs with channel
dimension != 1 raising ValueError, and for cases where spatial_dims !=
self.len_spatial_sigma raising ValueError; implement tests by constructing small
tensors of appropriate shapes, calling the respective forward methods (or the
module forward via model(input_tensor)), and asserting output shapes or that
ValueError is raised, referencing the methods forward,
TrainableBilateralFilterFunction.apply, and the attribute len_spatial_sigma to
locate code under test.
🧹 Nitpick comments (1)
monai/networks/layers/filtering.py (1)

234-237: Document the new ValueError conditions in method docstrings.

Line 234 / Line 407 and Line 255 / Line 435 add explicit exceptions. Please add Google-style Raises: details for these forward methods.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

Also applies to: 255-255, 407-410, 435-435

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/networks/layers/filtering.py` around lines 234 - 237, The new explicit
ValueError checks in the forward methods (function name: forward) need to be
documented: update the Google-style docstrings for the forward methods in
monai/networks/layers/filtering.py to add a Raises: section that describes the
ValueError conditions (e.g., when input tensor has fewer than 3 dimensions or
when other explicit checks fail), include the exception type and a brief
description matching the raised message, and ensure both forward implementations
(the ones around the added checks) mention these Raises entries.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@monai/networks/layers/filtering.py`:
- Around line 233-267: The forward() validation branches now added need unit
tests: add tests exercising TrainableBilateralFilter.forward and the other
class's forward (the second forward around lines 406-447) for valid 1D/2D/3D
inputs returning same shaped tensors, for inputs with rank < 3 raising
ValueError, for inputs with channel dimension != 1 raising ValueError, and for
cases where spatial_dims != self.len_spatial_sigma raising ValueError; implement
tests by constructing small tensors of appropriate shapes, calling the
respective forward methods (or the module forward via model(input_tensor)), and
asserting output shapes or that ValueError is raised, referencing the methods
forward, TrainableBilateralFilterFunction.apply, and the attribute
len_spatial_sigma to locate code under test.

---

Nitpick comments:
In `@monai/networks/layers/filtering.py`:
- Around line 234-237: The new explicit ValueError checks in the forward methods
(function name: forward) need to be documented: update the Google-style
docstrings for the forward methods in monai/networks/layers/filtering.py to add
a Raises: section that describes the ValueError conditions (e.g., when input
tensor has fewer than 3 dimensions or when other explicit checks fail), include
the exception type and a brief description matching the raised message, and
ensure both forward implementations (the ones around the added checks) mention
these Raises entries.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c9e8460d-3a72-4376-bd8f-39abdacf3dc9

📥 Commits

Reviewing files that changed from the base of the PR and between 8264cac and a8c7a0e.

📒 Files selected for processing (1)
  • monai/networks/layers/filtering.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

can't use TrainableBilateralFilter for 3d image

2 participants